classdef logistic
    %LOGISTIC Ridge-regularized log-sum-exp objective with gradient and Hessian.
    %   obj = LOGISTIC(C, b, mu) builds the objective
    %
    %       f(x) = log( sum_j exp(c_j'*x - b_j) ) + (REG/2)*||x||^2
    %
    %   where c_j' is the j-th row of the data matrix C, b is subtracted from
    %   C*x, and REG = 1e-4. The constructor shifts every row of C by the
    %   gradient at x = 0, so that grad(zeros(p,1)) is zero up to rounding.
    %
    %   Despite the class name, LOSS/GRAD/HESSIAN evaluate a log-sum-exp
    %   (softmax-type) objective. The static helpers PHI and LOGIT (sigmoid and
    %   log-sigmoid) are not used by those methods.

    properties
        C;  % data matrix, n_samples (rows) x p_features (columns); holds the
            % row-shifted matrix produced by the constructor
        b;  % the labels/offsets; enters the objective as C*x - b
        mu; % stored by the constructor but not read by any method of this
            % class (the regularization weight is the constant REG)
        % f; % function handle
        % g; % gradient handle
        % H; % Hessian handle
    end

    properties (Constant, Access = private)
        % Ridge weight: LOSS adds REG/2*||x||^2, GRAD adds REG*x and
        % HESSIAN adds REG*I.
        REG = 1e-4;
    end

    methods(Static)
        function out = phi(t)
            %PHI Elementwise logistic (sigmoid) function 1./(1+exp(-t)).
            %   Each branch exponentiates only a non-positive number, so exp
            %   cannot overflow for large |t|.
            out = 0*t;
            idx = t > 0;
            out(idx) = 1./(1 + exp(-t(idx)));
            out(~idx) = exp(t(~idx))./(exp(t(~idx)) + 1);
        end

        function out = logit(t)
            %LOGIT Elementwise log of the logistic function:
            %   log(phi(t)) = -log(1 + exp(-t)).
            %   (Note: this is log-sigmoid, not the inverse-sigmoid "logit".)
            %   Both branches exponentiate only a non-positive number, and
            %   log1p keeps full precision when exp(.) is tiny.
            out = 0*t;
            idx = t > 0;
            out(idx) = -log1p(exp(-t(idx)));
            out(~idx) = t(~idx) - log1p(exp(t(~idx)));
        end

        function out = logsumexp(t)
            %LOGSUMEXP Numerically stable log(sum(exp(t))).
            %   Reduces along the same dimension as MAX(t) and SUM(t) (the
            %   first non-singleton one). The maximum is subtracted before
            %   exponentiating to avoid overflow. A non-finite maximum (all
            %   entries -Inf, or some entry +Inf) is replaced by 0. Otherwise
            %   t - t_max would evaluate -Inf-(-Inf) or Inf-Inf and return NaN
            %   instead of -Inf / +Inf.
            t_max = max(t);
            t_max(~isfinite(t_max)) = 0;
            out = log(sum(exp(t - t_max))) + t_max;
        end
    end

    methods
        function obj = logistic(C, b, mu)
            %LOGISTIC Store the data and center C so the gradient vanishes at 0.
            %   obj = LOGISTIC(C, b, mu) stores C, b and mu. It then evaluates
            %   the gradient at x = 0 using the unshifted C and subtracts its
            %   transpose from every row of C. The softmax weights at x = 0
            %   depend only on b and sum to one, so after the shift
            %   obj.grad(zeros(p,1)) is zero up to rounding.
            obj.C = C;
            obj.b = b;
            obj.mu = mu;
            p = size(C, 2);
            grad_0 = obj.grad(zeros(p, 1));
            obj.C = C - grad_0';   % implicit expansion: same shift for every row
        end

        function out = loss(obj, x)
            %LOSS Objective value
            %   f(x) = log( sum_j exp(c_j'*x - b_j) ) + (REG/2)*||x||^2,
            %   where c_j' is the j-th row of obj.C.
            out = sum(obj.logsumexp(obj.C*x - obj.b)) + 0.5*obj.REG*norm(x)^2;
        end

        function out = grad(obj, x)
            %GRAD Gradient of LOSS: C'*w + REG*x, with w = softmax(C*x - b).
            w = obj.softmax_weights(x);
            out = obj.C'*w + obj.REG*x;
        end

        function out = hessian(obj, x)
            %HESSIAN Hessian of LOSS:
            %   C'*diag(w)*C - (C'*w)*(C'*w)' + REG*I, with w = softmax(C*x - b).
            w = obj.softmax_weights(x);
            % Only the log-sum-exp part of the gradient enters the outer
            % product. The ridge term REG*x has zero curvature apart from
            % REG*I.
            g = obj.C'*w;
            p = size(obj.C, 2);
            % w .* C scales row j of C by w(j). This equals diag(w)*C without
            % forming the dense m x m matrix diag(w).
            out = obj.C'*(w .* obj.C) - g*g' + obj.REG*eye(p);
        end
    end

    methods (Access = private)
        function w = softmax_weights(obj, x)
            %SOFTMAX_WEIGHTS Softmax of y = C*x - b: positive weights summing to 1.
            y = obj.C*x - obj.b;
            y = y - max(y);   % softmax is shift-invariant; shifting avoids overflow
            e = exp(y);
            w = e / sum(e);
        end
    end
end

